"""
Main program: fit an SDMP model to a target GNN (given via --gnn) and save the results to the --result folder.
"""

# pylint: disable=anomalous-backslash-in-string
# pylint: disable=invalid-name
# pylint: disable=missing-function-docstring
import os
import sys
import argparse
import shutil
import warnings

from data_utils import SDMPDataPre
from graph_dict import SDMP, wrap_log_for_fig
from utils import load_sdmp_conf_with_default, Logger, export_train_conf

# Silence every Python warning for the rest of the run.
warnings.simplefilter('ignore')
# NOTE(review): with some versions of `requests`, an empty CURL_CA_BUNDLE
# effectively disables TLS certificate verification process-wide. Presumably a
# workaround for certificate errors while downloading data — confirm it is
# still needed before relying on it.
os.environ.update(CURL_CA_BUNDLE="")

## Parse command line parameters
_parser = argparse.ArgumentParser(
    description="Argument of the main.")

# YAML file with the SDMP hyper-parameters.
_parser.add_argument("-c", "--config", default="config/cora/SDMP/cora_SDMP_base.yml",
                     type=str, help="Path to the configuration file. ")
# Root of all datasets; defaults to the "dataset" folder next to this script.
_parser.add_argument("-d", "--data",
                     default=os.path.join(os.path.dirname(os.path.abspath(__file__)), "dataset"),
                     type=str, help="Path to the root data folder. ")
# Output folder; created if missing.
_parser.add_argument("-r", "--result", default="result/tmp", type=str,
                     help="Path to the result folder. ")
# Result folder of the previously produced target GNN whose artifacts are reused.
_parser.add_argument("-g", "--gnn", default="result/cora/SAGE/seed_123", type=str,
                     help="Path to the target GNN folder. ")
_parser.add_argument("-i", "--device", default="cuda:4", type=str,
                     help="Device name to run pytorch code. ")
_parser.add_argument("--eval_level",
                     type=int,
                     default=0,
                     help="0 for fast eval. "
                          "1 for full eval at beginning only. "
                          "2 for full eval at every evaluation.")

ARGS_GLOBAL = _parser.parse_args()


# Expose the parsed command-line options as module-level constants.
(DEVICE, DATA_ROOT_FOLDER, CONF_PATH, RES_FOLDER, TARGET_GNN_FOLDER) = (
    ARGS_GLOBAL.device,
    ARGS_GLOBAL.data,
    ARGS_GLOBAL.config,
    ARGS_GLOBAL.result,
    ARGS_GLOBAL.gnn,
)

## Initialization and hyper-parameter setting
# Load the SDMP training configuration (presumably merged with defaults,
# per the helper's name — see utils.load_sdmp_conf_with_default).
train_conf = load_sdmp_conf_with_default(CONF_PATH)

# Per-dataset folder under the data root; also hosts the preprocessing cache.
DATA_FOLDER = os.path.join(DATA_ROOT_FOLDER, train_conf["name"])

# exist_ok=True replaces a separate exists() check, avoiding the
# check-then-create race when several runs share a result folder.
os.makedirs(RES_FOLDER, exist_ok=True)

# Route all subsequent print() output through Logger, which is given
# <result>/log.txt (whether it also echoes to the console is up to utils.Logger).
sys.stdout = Logger(os.path.join(RES_FOLDER, "log.txt"))

# Record the CLI arguments and the resolved configuration for this run.
for k, v in vars(ARGS_GLOBAL).items():
    print(k, v)

print(train_conf)
## Preprocessing
# Locations of the target GNN's artifacts inside its result folder.
GNN_MODEL_PATH = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_path'])
GNN_CONF_PATH = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_conf_path'])
GNN_ACC_PATH = os.path.join(TARGET_GNN_FOLDER, train_conf['target_h_model_metric_path'])
GNN_DATA_SPLIT_PATH = os.path.join(TARGET_GNN_FOLDER, "data_split_seed.txt")

with open(GNN_DATA_SPLIT_PATH, 'r', encoding="utf-8") as fin:
    # strip() removes a trailing newline so the message stays on one line.
    print(f"Target GNN with split seed {fin.read().strip()}.")

# Load the data and build the SDMP inputs; results are cached under
# <DATA_FOLDER>/SDMPPre (use_cache=True) so reruns can reuse them.
preprocesser = SDMPDataPre(train_conf["name"], train_conf["feature_normalize"],
                           train_conf["target_h_mode"],
                           GNN_CONF_PATH, GNN_MODEL_PATH, train_conf["target_h_model"],
                           train_conf["h_init_theta_mode"], train_conf["h_init_theta_k"],
                           train_conf["h_init_theta_k_fanout"],
                           train_conf["theta_cand_mode"], train_conf["theta_cand_k2"],
                           train_conf["theta_cand_k1"], train_conf["theta_cand_fanout"],
                           train_conf["theta_cand_add_self"],
                           train_conf,
                           use_cache=True, cache_path=os.path.join(DATA_FOLDER, "SDMPPre"),
                           device=DEVICE)
preprocesser.disp_states()
# Pull out the prepared inputs consumed by the SDMP model below.
theta_cand, h_init_theta, X, target = (
    preprocesser.theta_cand, preprocesser.h_init_theta, preprocesser.X, preprocesser.target)

## Main algorithm
# Build the SDMP model from the preprocessed inputs, then fit it.
print("Initializing the model...")
test = SDMP(X, target, theta_cand, h_init_theta, train_conf,
            verbose=True, device=DEVICE)

print("Starting to fit...")
# --eval_level selects fast (0) or full (1, 2) evaluation during fitting.
test.fit(eval_level=ARGS_GLOBAL.eval_level)

## Save results
print("Saving the results...")
test.save(RES_FOLDER)
export_train_conf(os.path.join(RES_FOLDER, 'conf.yml'), train_conf)

# Copy the target GNN's artifacts next to the SDMP results so the result
# folder is self-contained. The GNN config is copied only in "internal" mode.
_gnn_artifacts = [(GNN_MODEL_PATH, 'GNN_target_state')]
if train_conf["target_h_mode"] == "internal":
    _gnn_artifacts.append((GNN_CONF_PATH, 'GNN_conf.yml'))
_gnn_artifacts += [
    (GNN_ACC_PATH, 'GNN_f1.txt'),
    (GNN_DATA_SPLIT_PATH, 'GNN_data_split_seed.txt'),
]
for _src, _dst_name in _gnn_artifacts:
    shutil.copyfile(_src, os.path.join(RES_FOLDER, _dst_name))

## For research purposes: only when full evaluation was requested
## (--eval_level 1 or 2), post-process the fit log with wrap_log_for_fig.
if ARGS_GLOBAL.eval_level > 0:
    wrap_log_for_fig(test.log, train_conf, RES_FOLDER)
